import gc
import os
import shutil
import sys
import time
import warnings
import numpy as np
import torch
from torch import nn, optim
import math
import json
import random
import scipy.io as sio
from torch.nn import functional as F
from scipy.io import savemat
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision
from data.dataloader import build_dataloader
import torchvision.models as torchvision_models
from torchvision import models, datasets, transforms
from utils import dist
from torch import distributed as tdist
import torchvision.transforms as transforms
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode, transforms

import config
from utils.util import Logger, LossManager, Pack
from data import dataloader
from model.vqvae import VQVAE
from metric.metric import PSNR, LPIPS, SSIM
from cleanfid import fid

def normalize_01_into_pm1(x):
    """Rescale a tensor from the [0, 1] range to the [-1, 1] range (``2 * x - 1``).

    The input tensor is left untouched; a new tensor is returned.
    """
    scaled = x + x  # out-of-place doubling, so the caller's tensor is not modified
    scaled -= 1     # safe to shift in place: ``scaled`` is a fresh tensor
    return scaled

def denormalize_pm1_into_01(x):
    """Rescale a tensor from the [-1, 1] range back to [0, 1] (``(x + 1) / 2``).

    The input tensor is left untouched; a new tensor is returned.
    """
    shifted = x + 1  # out-of-place, so the caller's tensor is not modified
    shifted *= 0.5   # in-place halving on the fresh tensor
    return shifted

def get_transform(mid_reso=1.125, final_reso=256):
    """Build the deterministic evaluation transform: resize, center-crop, map to [-1, 1].

    Args:
        mid_reso: ratio between the intermediate resize target and ``final_reso``. The
            shorter image edge is first resized (Lanczos) to ``round(mid_reso * final_reso)``
            pixels, then a square ``final_reso`` x ``final_reso`` center crop is taken.
        final_reso: side length of the output crop. Defaults to 256, the value that was
            previously hard-coded here.

    Returns:
        A ``transforms.Compose`` that turns an image into a tensor (via ``ToTensor``) and
        then rescales it from [0, 1] to [-1, 1] with ``normalize_01_into_pm1``.
    """
    # First resize to mid_size, then crop to final_reso; the margin avoids border artefacts.
    mid_size = round(mid_reso * final_reso)
    return transforms.Compose([
        # An int size resizes the shorter edge to mid_size and keeps the aspect ratio.
        transforms.Resize(mid_size, interpolation=InterpolationMode.LANCZOS),
        transforms.CenterCrop((final_reso, final_reso)),
        transforms.ToTensor(),
        normalize_01_into_pm1,
    ])

def load_dataset(data_path, batch_size=16, num_workers=32):
    """Create a non-shuffling DataLoader over an ``ImageFolder`` directory tree.

    Args:
        data_path: root directory laid out as ``<root>/<class>/<image>``.
        batch_size: number of images per batch.
        num_workers: number of DataLoader worker processes. Defaults to 32, the value that
            was previously hard-coded here; lower it on machines with fewer CPU cores.

    Returns:
        A ``DataLoader`` whose dataset is the ``ImageFolder``, with each image passed
        through ``get_transform()``.
    """
    dataset = datasets.ImageFolder(root=data_path, transform=get_transform())
    # shuffle=False keeps the iteration order equal to ``dataset.samples``, so batch
    # positions can be mapped back to source file paths.
    # The local is named ``loader`` so it does not shadow the module-level ``dataloader`` import.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return loader

def save_transformed_image(org_image, original_image_path, save_dir, prefix="ORG", **save_kwargs):
    """Save one image tensor normalized to [-1, 1] as an 8-bit image file.

    Args:
        org_image: image tensor of shape (C, H, W) or (1, C, H, W) with values in [-1, 1].
            Out-of-range values are clamped. C is expected to be 1 or 3.
        original_image_path: path of the source image; only its basename is reused.
        save_dir: destination directory, created if it does not exist.
        prefix: string prepended to the basename to build the output file name.
        **save_kwargs: extra keyword arguments forwarded to ``PIL.Image.Image.save``.

    Notes:
        PIL picks the output format from the file extension. A ``.JPEG`` source name is
        therefore written with lossy JPEG compression (Pillow's default quality is 75).
        Pass e.g. ``quality=100`` through ``save_kwargs`` when near-lossless output is needed.

    Raises:
        ValueError: if ``org_image`` is not a single image of shape (C, H, W) or (1, C, H, W).
    """
    image = org_image.detach()
    if image.dim() == 4 and image.size(0) == 1:
        image = image[0]  # drop a leading batch dimension of size 1 (never the channel dim)
    if image.dim() != 3:
        raise ValueError(
            "expected an image of shape (C, H, W) or (1, C, H, W), got {}".format(tuple(org_image.shape)))

    # [-1, 1] -> [0, 1] -> [0, 255]. Adding 0.5 before the uint8 cast rounds to the nearest
    # integer (as torchvision.utils.save_image does) instead of truncating, which would bias
    # every pixel downward by about half a gray level.
    image = denormalize_pm1_into_01(image).clamp_(0, 1)
    pixels = image.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    if pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]  # PIL expects (H, W) for single-channel images
    transformed_pil_image = Image.fromarray(pixels)

    filename = os.path.basename(original_image_path)
    save_path = os.path.join(save_dir, prefix + filename)
    os.makedirs(save_dir, exist_ok=True)
    transformed_pil_image.save(save_path, **save_kwargs)

def _shard_positions(num_samples, rank, world_size):
    """Return this rank's positions in an evenly padded, rank-strided split of the dataset.

    Every rank gets ``ceil(num_samples / world_size)`` positions, so all ranks run the same
    number of batches. Positions ``>= num_samples`` are padding: they are mapped back onto
    real samples with ``% num_samples`` to keep batch counts equal, but must not be saved.
    """
    per_rank = math.ceil(num_samples / world_size) if num_samples > 0 else 0
    return list(range(rank, per_rank * world_size, world_size))


def main_worker(args, epoch, save_org=False):
    """Reconstruct the ImageNet-1k validation images with the checkpoint of one epoch.

    Loads ``<checkpoint_dir>/checkpoint-<saver_name_pre>-<epoch>.pth.tar`` into a VQVAE.
    Runs every image of ``<dataset_dir>/ImageNet-1k/val`` through
    ``model.module.collect_eval_info``. Writes each reconstruction as
    ``REC_<basename>`` under ``<rec_image_dir>/<rec_name>/epoch-<epoch>/Rec``.

    Under a multi-process launch, the validation set is split across ranks, so each image
    is reconstructed and written by exactly one process. Previously every rank redid the
    whole set and wrote the same files concurrently.

    Args:
        args: parsed config; reads ``rec_image_dir``, ``rec_name``, ``checkpoint_dir``,
            ``saver_name_pre`` and ``dataset_dir`` here and passes ``args`` to ``VQVAE``.
        epoch: epoch number used in the checkpoint file name and output folder.
        save_org: if True, also write the preprocessed input images as ``ORG_<basename>``
            under ``.../epoch-<epoch>/Org``. That directory is created either way.
    """
    # Output layout: <rec_image_dir>/<rec_name>/epoch-<epoch>/{Rec,Org}
    rec_path = os.path.join(args.rec_image_dir, args.rec_name)
    epoch_path = os.path.join(rec_path, "epoch-" + str(epoch))
    rec_epoch_path = os.path.join(epoch_path, "Rec")
    org_epoch_path = os.path.join(epoch_path, "Org")
    for path in (epoch_path, rec_epoch_path, org_epoch_path):
        os.makedirs(path, exist_ok=True)

    local_rank = dist.get_local_rank()
    torch.cuda.set_device(local_rank)
    model = VQVAE(args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda(local_rank)
    # NOTE(review): the DDP wrapper is presumably kept so the state-dict keys match a
    # checkpoint saved from a DDP model; inference below calls ``model.module`` directly.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], find_unused_parameters=True, broadcast_buffers=True)

    checkpoint_name = 'checkpoint-' + args.saver_name_pre + '-' + str(epoch) + '.pth.tar'
    checkpoint_path = os.path.join(args.checkpoint_dir, checkpoint_name)
    checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(local_rank))
    model.load_state_dict(checkpoint['model'])
    del checkpoint  # release the rest of the checkpoint's GPU memory before inference
    model.eval()

    data_path = os.path.join(args.dataset_dir, 'ImageNet-1k', 'val')
    base_loader = load_dataset(data_path, batch_size=128)
    dataset = base_loader.dataset

    if tdist.is_available() and tdist.is_initialized():
        rank, world_size = tdist.get_rank(), tdist.get_world_size()
    else:
        rank, world_size = 0, 1
    num_samples = len(dataset)
    positions = _shard_positions(num_samples, rank, world_size)
    shard = torch.utils.data.Subset(dataset, [pos % num_samples for pos in positions])
    # Same batch size and worker count as load_dataset; order is deterministic (no shuffle).
    loader = DataLoader(shard, batch_size=base_loader.batch_size, shuffle=False,
                        num_workers=base_loader.num_workers)

    cursor = 0  # index into ``positions`` of the next image produced by ``loader``
    with torch.no_grad():
        for x, _ in tqdm(loader, desc="epoch-{}".format(epoch), disable=not dist.is_local_master()):
            x = x.cuda(local_rank, non_blocking=True)
            # NOTE(review): assumes the first of the 6 returned values is the reconstructed
            # batch, aligned with ``x`` and in [-1, 1] — confirm against VQVAE.collect_eval_info.
            x_rec, _, _, _, _, _ = model.module.collect_eval_info(x)

            for i, rec_img in enumerate(x_rec):
                position = positions[cursor]
                cursor += 1
                if position >= num_samples:
                    continue  # padding duplicate; another rank owns and writes this sample
                image_name = dataset.samples[position][0]
                save_transformed_image(rec_img, image_name, rec_epoch_path, prefix="REC_")
                if save_org:
                    save_transformed_image(x[i], image_name, org_epoch_path, prefix="ORG_")

    # Free this epoch's model before the next call builds another one.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    if world_size > 1:
        dist.barrier()  # every shard is on disk before any rank moves on

if __name__ == '__main__':
    # Bring up the distributed environment and line up all processes before any work.
    dist.initialize(fork=False, timeout=15)
    dist.barrier()

    args = config.parse_arg()
    if dist.is_local_master():
        # Echo the full configuration, only from the local master process.
        for key, value in vars(args).items():
            print("{0}: {1}".format(key, value))

    # Top-level output folders for reconstructed images and for result files.
    for out_root in (args.rec_image_dir, args.rec_results_dir):
        os.makedirs(os.path.join(out_root, args.rec_name), exist_ok=True)

    # Checkpoints to evaluate, newest first: epochs 20 down to 15.
    for epoch in range(20, 14, -1):
        main_worker(args, epoch)


    

